#!/usr/bin/env python3

import sys
import os
import serial
import struct
import progressbar
import argparse
import json
import time

import settings

# Chips handled through handshake_v500 / sendDataToBootrom_v500 rather than
# through a JSON profile loaded from the "profiles" directory.
gk7205v500_socs = [
    "gk7205v500",
    "gk7205v510",
    "gk7205v530",
    "xm7205v500",
    "xm7205v510",
    "xm7205v530",
]


class bootdownload():
    '''
    HiSilicon boot downloader

    >>> downloader = bootdownload(args)  # args: namespace with port, chip, debug
    >>> downloader.burn(filename, term)

    '''
    def init_bootmode(self):
        """Wait for five 0x20 bytes on the port, then send ``settings.AA``.

        Bytes are read one at a time. A ``0x00`` byte is skipped and does not
        count as an attempt; every other read uses up one of 30 attempts.
        As soon as the fifth ``0x20`` byte has been seen, ``settings.AA`` is
        written to the port and the method returns ``None``. If the attempts
        are used up first, the process exits with status 1.
        """
        spaces_seen = 0
        attempts_left = 30
        while attempts_left > 0:
            received = self.s.read(1)
            if received == b'\x00':
                # Zero bytes are ignored and do not consume an attempt.
                continue
            if received == b'\x20':
                spaces_seen += 1
            if spaces_seen == 5:
                self.s.flushOutput()
                self.s.write(settings.AA)
                self.s.flushInput()
                print("Welcome to boot-mode\n")
                return None
            attempts_left -= 1
        sys.exit(1)

    def load_config(self, chiptype):
        max_nesting = 23
        while max_nesting > 0:
            max_nesting -= 1
            with open("profiles/" + chiptype + ".json", "r") as read_file:
                file_contents = read_file.read()
                file_lines = file_contents.split()
                if len(file_lines) == 1 and file_lines[0].endswith(".json"):
                    chiptype = file_lines[0][0:-5]
                    continue
                return json.loads(file_contents)
        return None

    def __init__(self, args):
        """Open the serial port and bring the target into download mode.

        *args* must provide ``port`` (serial device name), ``chip`` (chip
        model name) and ``debug`` (debug flag).

        For chips listed in ``gk7205v500_socs`` only ``handshake_v500`` is
        run and no profile is loaded (``self.chip`` stays ``None``). For all
        other chips ``init_bootmode`` is run and the chip profile is loaded
        into ``self.chip``.

        Exits the process with status 2 if a serial error occurs while
        opening or initialising the port.
        """
        # Define these first so that __del__ and later attribute access are
        # safe even when construction fails early or takes the V500 path.
        self.s = None
        self.chip = None

        serialport = args.port
        self.chiptype = args.chip
        self.DEBUG = args.debug
        print("Trying open", serialport)
        try:
            self.s = serial.Serial(
                port=serialport,
                baudrate=115200,
                parity=serial.PARITY_NONE,
                stopbits=serial.STOPBITS_ONE,
                bytesize=serial.EIGHTBITS,
                # Blocking reads; individual methods set shorter timeouts.
                timeout=None
            )
            if self.chiptype in gk7205v500_socs:
                self.handshake_v500()
                return
            self.init_bootmode()

        except serial.serialutil.SerialException:
            # no serial connection
            self.s = None
            print("\nFailed to open serial!", serialport)
            sys.exit(2)

        data = self.load_config(self.chiptype)
        print(data)
        self.chip = data

    def __del__(self):
        """Close the serial port, if one is open, when the object goes away.

        Safe to call on a partially constructed instance (no ``s`` attribute)
        and safe to call more than once.
        """
        print("Exiting...")
        # __init__ may have failed before self.s was assigned.
        port = getattr(self, "s", None)
        if port is not None:
            # Drop the reference first so a second call does not close again.
            self.s = None
            port.close()

    def calc_crc(self, data, crc=0):
        """Return the 16-bit CRC of *data*, computed with ``settings.crctable``.

        *data* is an iterable of byte values (``bytes`` / ``bytearray``) and
        *crc* is the starting value. After the data bytes, two zero bytes are
        fed through the same update step before the result is returned.

        The running value is masked to 16 bits after every step. Each step
        only reads the low 16 bits of the previous value, so this yields the
        same result as masking once at the end, while keeping the accumulator
        from growing by 8 bits per input byte.
        """
        table = settings.crctable
        for byte in data:
            crc = (((crc << 8) | byte) ^ table[(crc >> 8) & 0xff]) & 0xffff
        # Flush: push two zero bytes through the register.
        for _ in range(2):
            crc = ((crc << 8) ^ table[(crc >> 8) & 0xff]) & 0xffff
        return crc

    def append_crc(self, frm):
        """Return *frm* followed by its CRC as two bytes, high byte first.

        The CRC comes from ``calc_crc``. A ``bytes`` argument yields a new
        object; a ``bytearray`` argument is extended in place and returned.
        """
        checksum = self.calc_crc(frm)
        frm += bytes(((checksum >> 8) & 255, checksum & 255))
        return frm

    def sendframe(self, data, loop):
        """Write *data* to the port and wait for a one-byte acknowledgement.

        The frame is sent up to ``loop - 1`` times. After each write one byte
        is read; if it equals ``settings.ack_b`` the method returns. When no
        attempt is acknowledged, ``'failed'`` is printed. The return value is
        always ``None``.

        With ``self.DEBUG == 1`` the frame (at most its first 20 bytes) and
        every received reply byte are printed.
        """
        debug = self.DEBUG == 1
        for _attempt in range(1, loop):
            self.s.flushOutput()
            if debug:
                # Only the first 20 bytes are shown for long frames.
                closing = "... ]" if len(data) > 20 else "]"
                print(
                    "len: ",
                    len(data),
                    "write : [",
                    ''.join('%02x ' % b for b in data[:20]), closing
                )
            self.s.write(data)
            self.s.flushInput()
            try:
                ack = self.s.read()
            except Exception:
                # NOTE(review): as before, a read error ends the retries
                # without reporting failure. Only the exception class was
                # narrowed, so Ctrl-C / SystemExit are no longer swallowed.
                return None
            if len(ack) == 1:
                if debug:
                    print(
                        "ret ack   : ",
                        ''.join('0x%02x ' % b for b in ack)
                    )
                if ack == settings.ack_b:
                    return None

        print('failed')

    def send_head_frame(self, length, address):
        """Send the 14-byte HEAD frame announcing a transfer.

        Frame layout: four marker bytes produced by
        ``settings.twos_complement_to_int`` for -2, 0, -1 and 1 (8 bits
        each), then *length* and *address* as 32-bit big-endian values, then
        the CRC of those 12 bytes, high byte first. The port timeout is set
        to 0.03 s and the frame is passed to ``sendframe`` with ``loop=16``.
        """
        print('Send HEAD frame...')
        self.s.timeout = 0.03

        frame = bytearray(
            settings.twos_complement_to_int(value, 8)
            for value in (-2, 0, -1, 1)
        )
        # Only the low 32 bits of each value go on the wire.
        frame += (length & 0xffffffff).to_bytes(4, 'big')
        frame += (address & 0xffffffff).to_bytes(4, 'big')

        checksum = self.calc_crc(frame[:12])
        frame += bytes(((checksum >> 8) & 0xff, checksum & 0xff))

        self.sendframe(frame, 16)

    def send_ddrstep(self):
        """Send the profile's ``DDRSTEP0`` data as a HEAD/DATA/TAIL sequence.

        The HEAD frame announces 64 bytes at ``ADDRESS[0]`` from the chip
        profile (a hex string). The single DATA frame carries sequence
        number 1 and the ``DDRSTEP0`` bytes; the TAIL frame carries sequence
        number 2.
        """
        print('Send DDRSTEP frame...')
        seq = 1
        # NOTE(review): send_head_frame sets the timeout to 0.03 right after
        # this, so 0.15 is not in effect for the DATA frame below. Confirm
        # which value is intended.
        self.s.timeout = 0.15

        self.send_head_frame(64, int(self.chip["ADDRESS"][0], 16))

        # DATA frame: 0xDA, sequence number, its bitwise complement, payload.
        payload = bytearray((0xDA, seq & 0xFF, ~seq & 0xFF))
        payload = payload + bytes(self.chip["DDRSTEP0"])
        checksum = self.calc_crc(payload)
        payload += bytes(((checksum >> 8) & 0xff, checksum & 0xff))

        self.sendframe(payload, 16)
        self.send_tail_frame(seq + 1)

    def send_tail_frame(self, seq):
        """Send the TAIL frame that closes a transfer.

        The frame is ``0xED``, the low byte of *seq*, the bitwise complement
        of that byte, and the CRC of those three bytes (high byte first).
        It is passed to ``sendframe`` with ``loop=16``.
        """
        print('Send TAIL frame...')
        frame = bytearray((0xED, seq & 0xFF, ~seq & 0xFF))

        checksum = self.calc_crc(frame)
        frame += bytes(((checksum >> 8) & 0xff, checksum & 0xff))

        self.sendframe(frame, 16)

    def send_data_frame(self, seq, data):
        """Send one DATA frame carrying *data* with sequence number *seq*.

        The frame is ``0xDA``, the low byte of *seq*, the bitwise complement
        of that byte, the payload, and the CRC of everything before it (high
        byte first). The port timeout is set to 0.15 s and the frame is
        passed to ``sendframe`` with ``loop=32``.
        """
        self.s.timeout = 0.15

        frame = bytearray((0xDA, seq & 0xFF, ~seq & 0xFF)) + data
        checksum = self.calc_crc(frame)
        frame += bytes(((checksum >> 8) & 0xff, checksum & 0xff))

        self.sendframe(frame, 32)

    def store_SPL(self, data):
        """Upload the leading part of *data* as a HEAD/DATA.../TAIL sequence.

        The number of bytes comes from ``FILELEN[1]`` and the target address
        from ``ADDRESS[1]`` in the chip profile (both hex strings). The data
        is split into DATA frames of at most ``settings.MAX_DATA_LEN`` bytes,
        numbered from 1; the TAIL frame carries the next sequence number.

        A progress bar is shown when ``self.DEBUG == 0``. It tracks the
        number of payload bytes actually sent, so it no longer jumps
        backwards when the final chunk is shorter than the others.
        """
        first_length = int(self.chip["FILELEN"][1], 16)
        address1 = int(self.chip["ADDRESS"][1], 16)
        self.send_head_frame(first_length, address1)

        show_progress = self.DEBUG == 0
        bar = progressbar.ProgressBar(maxval=first_length, widgets=[
            'Send DATA frame',
            progressbar.Bar(left='[', marker='=', right=']'),
            progressbar.SimpleProgress(),
        ], term_width=80).start() if show_progress else None

        seq = 1
        sent = 0
        data = data[:first_length]
        # NOTE(review): the frame count follows first_length, so if *data* is
        # shorter than that, the trailing frames carry empty payloads.
        for offset in range(0, first_length, settings.MAX_DATA_LEN):
            chunk = data[offset:offset + settings.MAX_DATA_LEN]
            self.send_data_frame(seq, chunk)
            sent += len(chunk)
            if show_progress:
                bar.update(sent)
            seq += 1
        if show_progress:
            bar.finish()
        self.send_tail_frame(seq)

    def store_Uboot(self, data):
        """Upload all of *data* as a HEAD/DATA.../TAIL sequence.

        The target address comes from ``ADDRESS[2]`` in the chip profile (a
        hex string). The data is split into DATA frames of at most
        ``settings.MAX_DATA_LEN`` bytes, numbered from 1; the TAIL frame
        carries the next sequence number.

        A progress bar is shown when ``self.DEBUG == 0``. It tracks the
        number of payload bytes actually sent, so it no longer jumps
        backwards when the final chunk is shorter than the others.
        """
        length = len(data)
        address2 = int(self.chip["ADDRESS"][2], 16)
        self.send_head_frame(length, address2)

        show_progress = self.DEBUG == 0
        bar = progressbar.ProgressBar(maxval=length, widgets=[
            'Send DATA frame',
            progressbar.Bar(left='[', marker='=', right=']'),
            progressbar.SimpleProgress(),
        ], term_width=80).start() if show_progress else None

        seq = 1
        sent = 0
        for offset in range(0, length, settings.MAX_DATA_LEN):
            chunk = data[offset:offset + settings.MAX_DATA_LEN]
            self.send_data_frame(seq, chunk)
            sent += len(chunk)
            if show_progress:
                bar.update(sent)
            seq += 1
        if show_progress:
            bar.finish()
        self.send_tail_frame(seq)

    def sendFrmWaitAck(self, wdata, timeout):
        """Write *wdata* and wait for the reply ``0xAA``.

        While fewer than *timeout* seconds have elapsed and fewer than 10
        retransmissions have been made, replies are read from the port:
        ``0xAA`` means success, ``b'U'`` triggers a retransmission of
        *wdata*, and any other reply is ignored.

        Returns ``True`` on ``0xAA``. Returns ``False`` if a write reports
        nothing written, a read returns ``None``, or the time / retry limit
        is reached.
        """
        if not self.s.write(wdata):
            return False

        resends = 0
        began = time.time()
        while resends < 10 and time.time() - began < timeout:
            reply = self.s.read()
            if reply is None:
                return False
            if reply == b'\xaa':
                return True
            if reply != b'U':
                continue
            # b'U' received: send the frame again.
            if not self.s.write(wdata):
                return False
            resends += 1
        return False

    def sendDataToBootrom_v500(self, data, address, title):
        """Upload *data* to *address* using the V500 frame sequence.

        Sends a head frame (``FE 00 FF 01`` plus 32-bit big-endian length and
        address), then data frames of up to 1024 bytes (``DA``, sequence
        number, its complement, payload), then a tail frame (``ED``, next
        sequence number, its complement). Each frame gets a CRC appended via
        ``append_crc`` and must be acknowledged through ``sendFrmWaitAck``
        (4 second limit per frame).

        *title* labels the progress bar, which is shown when
        ``self.DEBUG == 0`` and tracks the bytes acknowledged so far (it no
        longer overstates or regresses on a short final chunk).

        Returns ``True`` on success; otherwise the failing
        ``sendFrmWaitAck`` result is returned and the transfer stops.
        """
        total = len(data)

        head = b'\xfe\x00\xff\x01'
        head += struct.pack('>II', total & 0xffffffff, address & 0xffffffff)
        ret = self.sendFrmWaitAck(self.append_crc(head), 4)
        if ret is not True:
            return ret

        show_progress = self.DEBUG == 0
        bar = progressbar.ProgressBar(maxval=total, widgets=[
            title,
            progressbar.Bar(left='[', marker='=', right=']'),
            progressbar.Percentage(),
        ], term_width=80).start() if show_progress else None

        seq = 0
        for pos in range(0, total, 1024):
            seq += 1
            chunk = data[pos:pos + 1024]
            frame = b'\xda' + struct.pack('BB', seq & 255, ~seq & 255) + chunk
            ret = self.sendFrmWaitAck(self.append_crc(frame), 4)
            if ret is not True:
                return ret
            if show_progress:
                bar.update(pos + len(chunk))

        # The tail frame carries the sequence number after the last chunk.
        seq += 1
        tail = b'\xed' + struct.pack('BB', seq & 255, ~seq & 255)
        ret = self.sendFrmWaitAck(self.append_crc(tail), 4)
        if ret is not True:
            return ret
        if show_progress:
            bar.finish()
        return True

    def handshake_v500(self):
        """Send bootrom handshake until a valid response is received."""
        frmRemoteBurnBoot = self.append_crc(
            b'\xbd\x00\xff\x01\x00\x00\x00\x00\x00\x00\x00\x00')
        try:
            start_time = time.time()
            self.s.timeout = 0
            while time.time() - start_time < 20:
                self.s.write(frmRemoteBurnBoot)
                response = self.s.read(14)
                if response.startswith(b'\xbd\x00'):
                    chipid, = struct.unpack('>I', response[8:12])
                    print("Detected SoC:", hex(chipid))
                    break
        except:
            print(f"Handshake error")
            sys.exit(1)

    def send_data(self, data):
        """Upload the boot image *data* using the protocol for this chip.

        Chips in ``gk7205v500_socs`` get three V500 transfers: the first
        8192 bytes to address 0 ("HEAD"), then the area starting at offset
        8192 whose length is read from bytes 1024..1027 of the image ("AUX",
        sent to address 8192), then after a 0.1 s pause the whole image to
        0x41000000 ("BOOT").

        All other chips get the DDR step, the SPL part and the full image
        via ``send_ddrstep``, ``store_SPL`` and ``store_Uboot``.
        """
        if self.chiptype not in gk7205v500_socs:
            self.send_ddrstep()
            self.store_SPL(data)
            self.store_Uboot(data)
            return

        head_offset = 0
        head_len = 8192
        aux_offset = 8192
        # NOTE(review): 'I' uses the host's native byte order and size.
        (aux_len,) = struct.unpack('I', data[1024:1028])

        head_area = data[head_offset:head_offset + head_len]
        aux_area = data[aux_offset:aux_offset + aux_len]

        # NOTE(review): the result of each stage is ignored, so a failed
        # stage does not stop the following ones.
        self.sendDataToBootrom_v500(head_area, head_offset, "Sending HEAD")
        self.sendDataToBootrom_v500(aux_area, aux_offset, "Sending  AUX")
        time.sleep(0.1)
        self.sendDataToBootrom_v500(data, 0x41000000, "Sending BOOT")

    def burn(self, filename, term):
        f = open(filename, "rb")
        data = f.read()
        f.close()

        print('Sending', filename, '...')
        self.send_data(data)
        print('Done\n')

        if term:
            print('Sending Ctrl-C\n')
            for i in range(1, 50):
                self.s.write(b'\x03')
                pkt = self.s.read(1)


def available_chips():
    """Return the chip names accepted on the command line.

    These are the base names of all ``*.json`` files in the ``profiles``
    directory, followed by the entries of ``gk7205v500_socs``.
    """
    profile_names = [
        entry[:-5]
        for entry in os.listdir("profiles")
        if entry.endswith(".json")
    ]
    return [*profile_names, *gk7205v500_socs]


def _build_arg_parser():
    """Create the command-line parser with the tool's five options."""
    options = (
        (("-c", "--chip"), dict(required=True, choices=available_chips(),
                                help="Chip model name")),
        (("-f", "--file"), dict(required=True,
                                help="U-Boot binary file to upload")),
        (("-p", "--port"), dict(default="/dev/ttyUSB0",
                                help="Serial port device name")),
        (("-b", "--break"), dict(dest='term', action='store_true',
                                 help="Send Ctrl-C to prevent autoload")),
        (("-d", "--debug"), dict(action='store_true',
                                 help="Set debug mode")),
    )
    parser = argparse.ArgumentParser()
    for flags, settings_for_flag in options:
        parser.add_argument(*flags, **settings_for_flag)
    return parser


def main():
    """Parse the command line, connect to the target and upload the file."""
    args = _build_arg_parser().parse_args()

    downloader = bootdownload(args)
    downloader.burn(args.file, args.term)


# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
